import pandas as pd
import numpy as np

from scipy import stats
import statsmodels.api as sm
from statsmodels.stats.oneway import anova_oneway
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics import confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import os

# =========================
# CONFIGURATION
# =========================

# Input table read by main(); presumably one row per text (see file name).
CSV_PATH = "ud_metrics_by_text.csv"
# Directory that receives every table and figure written by this script.
OUT_DIR = "salidas_estadistica"

# Dependency-relation rates (per 1000 tokens, judging by the column names).
# Together they form the multivariate syntactic profile used by MANOVA, PCA,
# LDA, PERMANOVA and the centroid distances.
SYN_VARS = [
    "ccomp_per1000",
    "xcomp_per1000",
    "advcl_per1000",
    "acl_relcl_per1000",
    "cc_per1000",
    "conj_per1000",
]

# Size / length descriptors of each text.
STRUCT_VARS = [
    "tokens",
    "sentences",
    "tokens_per_sentence",
]

# Every variable analysed univariately: structural ones first, then syntactic.
ALL_VARS = [*STRUCT_VARS, *SYN_VARS]

def ensure_outdir():
    """Create OUT_DIR (and any missing parents); do nothing if it already exists."""
    target_dir = OUT_DIR
    os.makedirs(name=target_dir, exist_ok=True)


def hedges_g(x, y):
    """Return Hedges' g, the bias-corrected standardized mean difference of x vs y.

    g = J * (mean(x) - mean(y)) / s_pooled, where s_pooled is the pooled
    sample standard deviation (variances with ddof=1) and
    J = 1 - 3 / (4 * df - 1), df = nx + ny - 2, is the small-sample correction.

    Parameters
    ----------
    x, y : array-like of numbers
        The two samples. Missing values are not filtered here; drop them
        before calling.

    Returns
    -------
    float
        Hedges' g, positive when x has the larger mean. 0.0 when the pooled
        standard deviation is zero. NaN when either sample has fewer than two
        observations, because the sample variance is undefined.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    nx, ny = x.size, y.size
    # ddof=1 variances need >= 2 observations per group; without this guard
    # numpy warns and the result degenerates to NaN/inf anyway.
    if nx < 2 or ny < 2:
        return float("nan")
    sx, sy = np.var(x, ddof=1), np.var(y, ddof=1)
    s_pooled = np.sqrt(((nx - 1) * sx + (ny - 1) * sy) / (nx + ny - 2))
    if s_pooled == 0:
        return 0.0
    d = (np.mean(x) - np.mean(y)) / s_pooled
    # 4 * (nx + ny) - 9 == 4 * df - 1 with df = nx + ny - 2.
    J = 1 - (3 / (4 * (nx + ny) - 9))
    return float(J * d)


def omega_squared_from_anova(ss_between, ss_within, df_between, ms_within, ss_total):
    """Return omega squared, a less biased effect size than eta squared.

    omega² = (SS_between - df_between * MS_within) / (SS_total + MS_within)

    ``ss_within`` is accepted but not used by the formula. The value is
    negative whenever SS_between < df_between * MS_within.
    """
    numerator = ss_between - df_between * ms_within
    denominator = ss_total + ms_within
    return float(numerator / denominator)

def _holm_adjust(pvals):
    """Return Holm step-down adjusted p-values, in the input order.

    Parameters
    ----------
    pvals : sequence of float
        Raw p-values of one family of tests. NaN entries stay NaN and are not
        counted in the family size m.

    Returns
    -------
    numpy.ndarray
        Adjusted p-values. After sorting the raw values ascending, the k-th
        (0-based) one is multiplied by (m - k). A running maximum then keeps
        the sequence monotone, and the result is capped at 1.
    """
    pvals = np.asarray(pvals, dtype=float)
    adjusted = np.full(pvals.shape, np.nan)
    valid = np.flatnonzero(~np.isnan(pvals))
    m = valid.size
    # Stable sort so tied p-values are processed in input order.
    order = valid[np.argsort(pvals[valid], kind="mergesort")]
    running_max = 0.0
    for rank, idx in enumerate(order):
        # Step-down: an adjusted p-value may never be smaller than that of a
        # test with a smaller raw p-value.
        running_max = max(running_max, (m - rank) * pvals[idx])
        adjusted[idx] = min(running_max, 1.0)
    return adjusted


def run_univariate(df):
    """Run the per-variable univariate analysis across models; write Tables 2-4.

    For every variable in ALL_VARS, using one NaN-free sample per model:

    * Assumptions (tabla2_supuestos_shapiro_levene.csv): Shapiro-Wilk p-value
      per model (NaN when the model has fewer than 3 values) and the
      median-centred Levene p-value across models.
    * Omnibus tests (tabla3_univariado_welch_kruskal.csv, sorted by welch_p):
      Welch ANOVA, Kruskal-Wallis, and omega² derived from a classical
      one-way OLS ANOVA table (typ=2).
    * Post hoc (tabla4_posthoc_welch_holm_hedgesg.csv): a Welch t-test for
      every pair of models, Holm-adjusted within the variable, plus Hedges' g
      (positive when group1 has the larger mean).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a "model" column and every column listed in ALL_VARS.
    """
    from itertools import combinations

    results = []
    posthoc_rows = []
    assume_rows = []

    models = df["model"].unique()

    for var in ALL_VARS:
        # One NaN-free sample per model, aligned with `models`. Every test
        # below, the pairwise ones included, uses these same samples.
        groups = [df.loc[df["model"] == m, var].dropna().values for m in models]

        # Shapiro-Wilk needs at least 3 observations.
        shapiro_p = {
            m: (stats.shapiro(vals).pvalue if len(vals) >= 3 else np.nan)
            for m, vals in zip(models, groups)
        }
        lev_p = stats.levene(*groups, center="median").pvalue

        assume_rows.append({
            "variable": var,
            "levene_p": lev_p,
            **{f"shapiro_p_{m}": shapiro_p[m] for m in models},
        })

        # Welch ANOVA (robust to unequal variances)
        welch = anova_oneway(groups, use_var="unequal", welch_correction=True)

        # Kruskal-Wallis (non-parametric robustness check)
        kruskal = stats.kruskal(*groups)

        # Classical OLS ANOVA, used only to obtain omega²
        # (statsmodels' formula interface drops rows with missing values by default).
        lm = sm.formula.ols(f"{var} ~ C(model)", data=df).fit()
        anova_tbl = sm.stats.anova_lm(lm, typ=2)

        ss_between = anova_tbl.loc["C(model)", "sum_sq"]
        ss_within = anova_tbl.loc["Residual", "sum_sq"]
        df_between = anova_tbl.loc["C(model)", "df"]
        df_within = anova_tbl.loc["Residual", "df"]
        ms_within = ss_within / df_within
        ss_total = ss_between + ss_within

        omega2 = omega_squared_from_anova(
            ss_between, ss_within, df_between, ms_within, ss_total
        )

        results.append({
            "variable": var,
            "welch_F": float(welch.statistic),
            "welch_p": float(welch.pvalue),
            "kruskal_H": float(kruskal.statistic),
            "kruskal_p": float(kruskal.pvalue),
            "omega_squared_approx": omega2,
        })

        # Pairwise Welch t-tests, in the same pair order as a nested i < j loop.
        pair_tests = []
        for (m1, x), (m2, y) in combinations(zip(models, groups), 2):
            t, p = stats.ttest_ind(x, y, equal_var=False)
            pair_tests.append((m1, m2, float(t), float(p), hedges_g(x, y)))

        # Holm correction within this variable's family of pairwise tests.
        holm_adj = _holm_adjust([pt[3] for pt in pair_tests])

        for (m1, m2, t, p, g), padj in zip(pair_tests, holm_adj):
            posthoc_rows.append({
                "variable": var,
                "group1": m1,
                "group2": m2,
                "welch_t": t,
                "p_raw": p,
                "p_holm": float(padj),
                "hedges_g": g,
            })

    res_df = pd.DataFrame(results).sort_values("welch_p")
    post_df = pd.DataFrame(posthoc_rows).sort_values(["variable", "p_holm"])
    assume_df = pd.DataFrame(assume_rows)

    res_df.to_csv(f"{OUT_DIR}/tabla3_univariado_welch_kruskal.csv", index=False)
    post_df.to_csv(f"{OUT_DIR}/tabla4_posthoc_welch_holm_hedgesg.csv", index=False)
    assume_df.to_csv(f"{OUT_DIR}/tabla2_supuestos_shapiro_levene.csv", index=False)

def run_descriptives(df):
    """Write Table 1: per-model mean and standard deviation of every ALL_VARS column."""
    by_model = df.groupby("model")
    summary = by_model[ALL_VARS].agg(["mean", "std"])
    out_path = f"{OUT_DIR}/tabla1_descriptivos.csv"
    summary.to_csv(out_path)


def run_manova(df):
    """One-way MANOVA of the syntactic profile (SYN_VARS) on `model`; write Table 5.

    The dependent-variable side of the formula is built from SYN_VARS, so the
    MANOVA always analyses the same profile as the PCA, LDA and PERMANOVA
    steps. The string form of ``mv_test()`` is written to tabla5_manova.txt.
    """
    from statsmodels.multivariate.manova import MANOVA

    # "ccomp_per1000 + xcomp_per1000 + ... ~ model", kept in sync with SYN_VARS.
    formula = " + ".join(SYN_VARS) + " ~ model"
    man = MANOVA.from_formula(formula, data=df)

    test = man.mv_test()

    with open(f"{OUT_DIR}/tabla5_manova.txt", "w", encoding="utf-8") as f:
        f.write(str(test))

def run_pca_and_plot(df):
    """Two-component PCA of the standardized syntactic profile; write Table 6 and Figure 1.

    Rows with a missing value in any SYN_VARS column are dropped first,
    because PCA cannot handle NaN. The variables are z-scored and projected
    onto the first two principal components. The component weights of each
    variable are saved to tabla6_pca_loadings.csv. A scatter of the rows,
    coloured by model, is saved to figura1_pca.png, with each axis labelled
    by its explained-variance percentage.
    """
    data = df.dropna(subset=SYN_VARS)
    X = data[SYN_VARS].to_numpy(dtype=float)
    y = data["model"].to_numpy()

    Xz = StandardScaler().fit_transform(X)

    pca = PCA(n_components=2, random_state=0)
    pcs = pca.fit_transform(Xz)

    # Save loadings (weights of each variable on PC1 and PC2)
    loadings = pd.DataFrame({
        "variable": SYN_VARS,
        "PC1_loading": pca.components_[0],
        "PC2_loading": pca.components_[1],
    })
    loadings.to_csv(f"{OUT_DIR}/tabla6_pca_loadings.csv", index=False)

    # Figure 1: rows in PC1/PC2 space, one colour per model
    fig, ax = plt.subplots()
    for m in np.unique(y):
        mask = y == m
        ax.scatter(pcs[mask, 0], pcs[mask, 1], label=m)

    evr = pca.explained_variance_ratio_
    ax.set_xlabel(f"PC1 ({evr[0]*100:.1f}%)")
    ax.set_ylabel(f"PC2 ({evr[1]*100:.1f}%)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{OUT_DIR}/figura1_pca.png", dpi=200)
    plt.close(fig)


def run_lda(df):
    """Cross-validated LDA classification of `model` from the syntactic profile; write Table 7.

    Rows with a missing SYN_VARS value are dropped. Standardization and LDA
    are chained in one pipeline, so the scaler is re-fitted inside every
    training fold and never sees the held-out fold. Out-of-fold predictions
    come from shuffled stratified 5-fold CV (random_state=0). The confusion
    matrix is written to tabla7_lda_confusion_matrix.csv and the overall
    accuracy to tabla7_lda_accuracy.txt.
    """
    from sklearn.pipeline import make_pipeline

    data = df.dropna(subset=SYN_VARS)
    X = data[SYN_VARS].to_numpy(dtype=float)
    y = data["model"].to_numpy()
    labels = np.unique(y)

    clf = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis())

    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
    y_pred = cross_val_predict(clf, X, y, cv=cv)

    acc = accuracy_score(y, y_pred)
    cm = confusion_matrix(y, y_pred, labels=labels)

    cm_df = pd.DataFrame(
        cm,
        index=[f"true_{m}" for m in labels],
        columns=[f"pred_{m}" for m in labels],
    )
    cm_df.to_csv(f"{OUT_DIR}/tabla7_lda_confusion_matrix.csv")

    with open(f"{OUT_DIR}/tabla7_lda_accuracy.txt", "w", encoding="utf-8") as f:
        f.write(f"5-fold CV accuracy: {acc:.3f}\n")


def effect_plot(df):
    """Figure 2: horizontal bars of the approximate omega² of each syntactic variable.

    Reads Table 3 (tabla3_univariado_welch_kruskal.csv), keeps only the
    SYN_VARS rows and sorts the bars by effect size. The ``df`` argument is
    accepted but not used.
    """
    table3 = pd.read_csv(f"{OUT_DIR}/tabla3_univariado_welch_kruskal.csv")
    is_syntactic = table3["variable"].isin(SYN_VARS)
    syn = table3.loc[is_syntactic].sort_values("omega_squared_approx")

    plt.figure()
    plt.barh(syn["variable"], syn["omega_squared_approx"])
    plt.xlabel("ω² aproximado (efecto global)")
    plt.tight_layout()
    plt.savefig(f"{OUT_DIR}/figura2_efectos_omega2.png", dpi=200)
    plt.close()


def permanova_manual(df, n_perm=10000, seed=0):
    """
    Permutation PERMANOVA on Euclidean distances in the standardized space.

    - Response: the SYN_VARS syntactic profile. Rows with a missing value in
      any of these columns are dropped, then the columns are z-scored.
    - Statistic: Anderson's (2001) pseudo-F computed from the distance matrix.
    - p-value: (count(F_perm >= F_obs) + 1) / (n_perm + 1). Each F_perm uses
      the model labels shuffled by a numpy Generator seeded with `seed`.
    - Effect size: R² = SS_between / SS_total.

    Writes one row (F_obs, p_perm, R2, n_perm, seed) to
    tabla8_permanova_manual.csv.

    Raises
    ------
    ValueError
        If there are fewer than two models, or no more rows than models
        (a pseudo-F degree of freedom would be zero).
    """
    rng = np.random.default_rng(seed)

    # Without dropping NaN rows every F would be NaN and no permutation could
    # satisfy F_p >= F_obs. p_perm would then silently collapse to its minimum
    # 1 / (n_perm + 1).
    data = df.dropna(subset=SYN_VARS)
    X = data[SYN_VARS].to_numpy(dtype=float)
    y = data["model"].to_numpy()

    Xz = StandardScaler().fit_transform(X)

    # Squared Euclidean distances (n x n) from the Gram matrix
    G = Xz @ Xz.T
    sq_norms = np.diag(G)
    D2 = sq_norms[:, None] - 2 * G + sq_norms[None, :]
    D2[D2 < 0] = 0.0  # clip tiny negatives caused by floating-point round-off

    n = D2.shape[0]
    labels = np.unique(y)
    k = len(labels)

    df_between = k - 1
    df_within = n - k
    if df_between < 1 or df_within < 1:
        raise ValueError(
            f"PERMANOVA needs at least 2 models and more rows than models "
            f"(got {k} models, {n} rows)."
        )

    # SST and SSW both sum over ordered pairs (i, j) and (j, i), i.e. twice
    # Anderson's sum over i < j. The factor 2 cancels in pseudo-F and R².
    SST = D2.sum() / n

    def ssw(groups):
        ssw_val = 0.0
        for lab in labels:
            idx = np.where(groups == lab)[0]
            m = len(idx)
            if m <= 1:
                continue  # a singleton group contributes no within-group SS
            ssw_val += D2[np.ix_(idx, idx)].sum() / m
        return ssw_val

    def pseudo_f(ssw_val):
        return ((SST - ssw_val) / df_between) / (ssw_val / df_within)

    # Observed statistic
    SSW_obs = ssw(y)
    SSB_obs = SST - SSW_obs
    F_obs = pseudo_f(SSW_obs)

    # Permutation distribution
    count = 0
    for _ in range(n_perm):
        if pseudo_f(ssw(rng.permutation(y))) >= F_obs:
            count += 1

    p_perm = (count + 1) / (n_perm + 1)

    # R²-type effect size
    R2 = SSB_obs / SST

    out = pd.DataFrame([{
        "F_obs": float(F_obs),
        "p_perm": float(p_perm),
        "R2": float(R2),
        "n_perm": int(n_perm),
        "seed": int(seed),
    }])

    out.to_csv(f"{OUT_DIR}/tabla8_permanova_manual.csv", index=False, encoding="utf-8")

def centroid_distances_bootstrap(df, n_boot=5000, seed=2026, cov_type="global"):
    """
    Mahalanobis distances between model centroids with percentile bootstrap CIs.

    Rows with a missing SYN_VARS value are dropped and the profile is z-scored.
    For every pair of models the observed centroid distance is computed.
    Each model's rows are then resampled with replacement ``n_boot`` times
    to obtain a 2.5-97.5 percentile interval. The covariance matrix is
    estimated once from the observed data and held fixed across replicates.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the SYN_VARS columns and a "model" column.
    n_boot : int
        Bootstrap replicates per pair of models.
    seed : int
        Seed for ``numpy.random.default_rng``.
    cov_type : {"global", "pooled"}
        "global" (default, original behaviour) uses the covariance of all rows
        together, which also contains the between-model spread and tends to
        shrink the distances. "pooled" uses the pooled within-model
        covariance, the usual choice for Mahalanobis distances between group
        centroids.

    Writes tabla9_distancias_centroides_bootstrap.csv.

    Raises
    ------
    ValueError
        For an unknown ``cov_type``, or for "pooled" when there are no more
        rows than models.
    """
    from itertools import combinations

    rng = np.random.default_rng(seed)

    data = df.dropna(subset=SYN_VARS)
    X = data[SYN_VARS].to_numpy(dtype=float)
    y = data["model"].to_numpy()

    Xz = StandardScaler().fit_transform(X)
    labels = np.unique(y)

    if cov_type == "global":
        # Covariance of all rows pooled together
        cov = np.cov(Xz, rowvar=False)
    elif cov_type == "pooled":
        # Pooled within-model covariance: centre each model on its own mean
        resid = np.empty_like(Xz)
        for lab in labels:
            mask = y == lab
            resid[mask] = Xz[mask] - Xz[mask].mean(axis=0)
        dof = len(y) - len(labels)
        if dof < 1:
            raise ValueError("cov_type='pooled' needs more rows than models.")
        cov = resid.T @ resid / dof
    else:
        raise ValueError(f"cov_type must be 'global' or 'pooled', got {cov_type!r}")
    # Pseudo-inverse tolerates a singular covariance (e.g. collinear variables).
    cov_inv = np.linalg.pinv(cov)

    # Observed centroids
    centroids = {lab: Xz[y == lab].mean(axis=0) for lab in labels}

    def mahalanobis(a, b):
        diff = a - b
        # Clip round-off negatives so sqrt never yields NaN.
        return float(np.sqrt(max(diff @ cov_inv @ diff, 0.0)))

    rows = []

    # Same pair order as a nested i < j loop, so the RNG stream is unchanged.
    for m1, m2 in combinations(labels, 2):
        d_obs = mahalanobis(centroids[m1], centroids[m2])

        # Bootstrap: resample within each model
        pool1 = np.where(y == m1)[0]
        pool2 = np.where(y == m2)[0]
        d_boot = []
        for _ in range(n_boot):
            idx1 = rng.choice(pool1, size=len(pool1), replace=True)
            idx2 = rng.choice(pool2, size=len(pool2), replace=True)
            d_boot.append(mahalanobis(Xz[idx1].mean(axis=0), Xz[idx2].mean(axis=0)))

        rows.append({
            "model_1": str(m1),
            "model_2": str(m2),
            "distance_mahalanobis": float(d_obs),
            "ci_2.5": float(np.percentile(d_boot, 2.5)),
            "ci_97.5": float(np.percentile(d_boot, 97.5)),
            "n_boot": int(n_boot),
            "seed": int(seed),
        })

    out = pd.DataFrame(rows)
    out.to_csv(f"{OUT_DIR}/tabla9_distancias_centroides_bootstrap.csv",
               index=False, encoding="utf-8")

def plot_centroid_distance_heatmap():
    """
    Figure 3: heatmap of Mahalanobis distances between syntactic centroids.

    Reads Table 9 (tabla9_distancias_centroides_bootstrap.csv) and rebuilds
    the symmetric model x model distance matrix. The diagonal is 0 and any
    pair absent from the table is NaN. Each off-diagonal cell is annotated
    with its value to two decimals.
    """
    # Model names were written with str(). Read them back as strings: letting
    # pandas infer dtypes would turn "01" into 1, or give one column ints and
    # the other strings, which breaks sorted() and the .loc lookups below.
    table = pd.read_csv(
        f"{OUT_DIR}/tabla9_distancias_centroides_bootstrap.csv",
        dtype={"model_1": str, "model_2": str},
    )

    models = sorted(set(table["model_1"]).union(table["model_2"]))

    # Symmetric distance matrix
    mat = pd.DataFrame(np.nan, index=models, columns=models)
    for m in models:
        mat.loc[m, m] = 0.0

    for m1, m2, d in zip(table["model_1"], table["model_2"], table["distance_mahalanobis"]):
        mat.loc[m1, m2] = float(d)
        mat.loc[m2, m1] = float(d)

    plt.figure()
    im = plt.imshow(mat.values, interpolation="nearest")
    plt.colorbar(im)

    plt.xticks(range(len(models)), models)
    plt.yticks(range(len(models)), models)

    # Annotate off-diagonal cells
    for i in range(len(models)):
        for j in range(len(models)):
            if i != j:
                plt.text(j, i, f"{mat.iloc[i, j]:.2f}", ha="center", va="center")

    plt.title("Distancias Mahalanobis entre centroides sintácticos")
    plt.tight_layout()
    plt.savefig(f"{OUT_DIR}/figura3_heatmap_distancias.png", dpi=200)
    plt.close()

def main():
    """Run the whole analysis pipeline on CSV_PATH and write every output to OUT_DIR."""
    ensure_outdir()

    data = pd.read_csv(CSV_PATH)
    # Model identifiers are handled as text throughout.
    data["model"] = data["model"].astype(str)

    # Steps that take only the DataFrame, in order. effect_plot reads
    # Table 3, so it must run after run_univariate.
    for step in (
        run_descriptives,
        run_univariate,
        run_manova,
        run_pca_and_plot,
        run_lda,
        effect_plot,
    ):
        step(data)

    # Permutation PERMANOVA, bootstrap centroid distances, and their heatmap
    # (which reads Table 9, so it comes last).
    permanova_manual(data, n_perm=10000, seed=2026)
    centroid_distances_bootstrap(data, n_boot=5000, seed=2026)
    plot_centroid_distance_heatmap()

    print("Análisis completado.")
    print(f"Revisa la carpeta: {OUT_DIR}")


if __name__ == "__main__":
    main()